/*
 * Scala.js (https://www.scala-js.org/)
 *
 * Copyright EPFL.
 *
 * Licensed under Apache License 2.0
 * (https://www.apache.org/licenses/LICENSE-2.0).
 *
 * See the NOTICE file distributed with this work for
 * additional information regarding copyright ownership.
 */

package org.scalajs.linker.backend.wasmemitter

import scala.annotation.switch

import org.scalajs.ir.Position
import org.scalajs.ir.Printers._
import org.scalajs.ir.Transformers._
import org.scalajs.ir.Traversers._
import org.scalajs.ir.Trees._
import org.scalajs.ir.Types._

import org.scalajs.linker.backend.webassembly.{Instructions => wa}

/** Transients generated by the optimizer that only make sense in Wasm.
 *
 *  Every node defined here is a `Transient.Value`. Each one provides its own
 *  result type, traversal, transformation and IR printing.
 */
object WasmTransients {

  /** Wasm unary op.
   *
   *  WebAssembly has dedicated opcodes for a number of operations that do not
   *  exist in the IR and are otherwise only implemented in user space.
   *  `WasmUnaryOp` is best understood as an extension of `ir.Trees.UnaryOp`
   *  that covers those opcodes.
   *
   *  Wasm unary ops always preserve pureness.
   *
   *  @param op
   *    One of the codes declared in the companion object.
   *  @param lhs
   *    The single operand.
   */
  final case class WasmUnaryOp(op: WasmUnaryOp.Code, lhs: Tree)
      extends Transient.Value {
    import WasmUnaryOp._

    /** Result type, determined entirely by `op`. */
    val tpe: Type = resultTypeOf(op)

    def traverse(traverser: Traverser): Unit = {
      traverser.traverse(lhs)
    }

    def transform(transformer: Transformer)(implicit pos: Position): Tree = {
      val newLhs = transformer.transform(lhs)
      Transient(WasmUnaryOp(op, newLhs))
    }

    /** The Wasm instruction that implements `op`. */
    def wasmInstr: wa.SimpleInstr = (op: @switch) match {
      // i32 bit counting
      case I32Clz    => wa.I32Clz
      case I32Ctz    => wa.I32Ctz
      case I32Popcnt => wa.I32Popcnt

      // i64 bit counting
      case I64Clz    => wa.I64Clz
      case I64Ctz    => wa.I64Ctz
      case I64Popcnt => wa.I64Popcnt

      // Floating-point operations
      case F32Abs     => wa.F32Abs
      case F64Abs     => wa.F64Abs
      case F64Ceil    => wa.F64Ceil
      case F64Floor   => wa.F64Floor
      case F64Nearest => wa.F64Nearest
      case F64Sqrt    => wa.F64Sqrt

      // Reinterpretations between integer and floating-point bit patterns
      case I32ReinterpretF32 => wa.I32ReinterpretF32
      case I64ReinterpretF64 => wa.I64ReinterpretF64
      case F32ReinterpretI32 => wa.F32ReinterpretI32
      case F64ReinterpretI64 => wa.F64ReinterpretI64
    }

    /** Prints the instruction mnemonic prefixed by `$`, then the operand. */
    def printIR(out: IRTreePrinter): Unit = {
      out.print("$" + wasmInstr.mnemonic)
      out.printArgs(lhs :: Nil)
    }
  }

  object WasmUnaryOp {
    /** Opcodes are plain `Int`s so that matches on them can compile to switches. */
    type Code = Int

    // Declared as `final val`s without a type ascription so that they are
    // constants, which `@switch` matches require.

    final val I32Clz = 1
    final val I32Ctz = 2
    final val I32Popcnt = 3

    final val I64Clz = 4
    final val I64Ctz = 5
    final val I64Popcnt = 6

    final val F32Abs = 7

    final val F64Abs = 8
    final val F64Ceil = 9
    final val F64Floor = 10
    final val F64Nearest = 11
    final val F64Sqrt = 12

    final val I32ReinterpretF32 = 13
    final val I64ReinterpretF64 = 14
    final val F32ReinterpretI32 = 15
    final val F64ReinterpretI64 = 16

    /** The IR type produced by the unary op `op`. */
    def resultTypeOf(op: Code): Type = (op: @switch) match {
      case F64Abs | F64Ceil | F64Floor | F64Nearest | F64Sqrt | F64ReinterpretI64 =>
        DoubleType

      case F32Abs | F32ReinterpretI32 =>
        FloatType

      case I64Clz | I64Ctz | I64Popcnt | I64ReinterpretF64 =>
        LongType

      case I32Clz | I32Ctz | I32Popcnt | I32ReinterpretF32 =>
        IntType
    }
  }

  /** Wasm binary op.
   *
   *  WebAssembly has dedicated opcodes for a number of operations that do not
   *  exist in the IR and are otherwise only implemented in user space.
   *  `WasmBinaryOp` is best understood as an extension of `ir.Trees.BinaryOp`
   *  that covers those opcodes.
   *
   *  Unsigned divisions and remainders have always-unchecked undefined
   *  behavior when their rhs is 0. Code that generates these transient nodes
   *  is responsible for checking for 0 itself when necessary.
   *
   *  Every other Wasm binary op preserves pureness.
   *
   *  @param op
   *    One of the codes declared in the companion object.
   *  @param lhs
   *    The left operand.
   *  @param rhs
   *    The right operand.
   */
  final case class WasmBinaryOp(op: WasmBinaryOp.Code, lhs: Tree, rhs: Tree)
      extends Transient.Value {
    import WasmBinaryOp._

    /** Result type, determined entirely by `op`. */
    val tpe: Type = resultTypeOf(op)

    def traverse(traverser: Traverser): Unit = {
      traverser.traverse(lhs)
      traverser.traverse(rhs)
    }

    def transform(transformer: Transformer)(implicit pos: Position): Tree = {
      // Operands are transformed left to right.
      val newLhs = transformer.transform(lhs)
      val newRhs = transformer.transform(rhs)
      Transient(WasmBinaryOp(op, newLhs, newRhs))
    }

    /** The Wasm instruction that implements `op`. */
    def wasmInstr: wa.SimpleInstr = (op: @switch) match {
      // Unsigned comparison
      case I32GtU => wa.I32GtU

      // i32 unsigned division/remainder and rotations
      case I32DivU => wa.I32DivU
      case I32RemU => wa.I32RemU
      case I32Rotl => wa.I32Rotl
      case I32Rotr => wa.I32Rotr

      // i64 unsigned division/remainder and rotations
      case I64DivU => wa.I64DivU
      case I64RemU => wa.I64RemU
      case I64Rotl => wa.I64Rotl
      case I64Rotr => wa.I64Rotr

      // Floating-point min/max
      case F32Min => wa.F32Min
      case F32Max => wa.F32Max
      case F64Min => wa.F64Min
      case F64Max => wa.F64Max
    }

    /** Prints the instruction mnemonic prefixed by `$`, then both operands. */
    def printIR(out: IRTreePrinter): Unit = {
      out.print("$" + wasmInstr.mnemonic)
      out.printArgs(lhs :: rhs :: Nil)
    }
  }

  object WasmBinaryOp {
    /** Opcodes are plain `Int`s so that matches on them can compile to switches. */
    type Code = Int

    // Declared as `final val`s without a type ascription so that they are
    // constants, which `@switch` matches require.

    final val I32GtU = 1

    final val I32DivU = 2
    final val I32RemU = 3
    final val I32Rotl = 4
    final val I32Rotr = 5

    final val I64DivU = 6
    final val I64RemU = 7
    final val I64Rotl = 8
    final val I64Rotr = 9

    final val F32Min = 10
    final val F32Max = 11

    final val F64Min = 12
    final val F64Max = 13

    /** The IR type produced by the binary op `op`. */
    def resultTypeOf(op: Code): Type = (op: @switch) match {
      case F64Min | F64Max =>
        DoubleType

      case F32Min | F32Max =>
        FloatType

      case I64DivU | I64RemU | I64Rotl | I64Rotr =>
        LongType

      case I32DivU | I32RemU | I32Rotl | I32Rotr =>
        IntType

      case I32GtU =>
        BooleanType
    }
  }

  /** Wasm intrinsic for `jl.Character.toString(int)`.
   *
   *  Typing rules: `codePoint` must be an `int`. The result is a `string`.
   *
   *  Evaluation semantics:
   *
   *  1. Let `codePointV` be the result of evaluating `codePoint`.
   *  2. If `codePointV` is not a valid code point, the behavior is undefined.
   *     This transient *assumes* that `codePointV` is a valid code point.
   *  3. Return a string of 1 or 2 chars representing that code point.
   */
  final case class WasmStringFromCodePoint(codePoint: Tree)
      extends Transient.Value {

    val tpe: Type = StringType

    def traverse(traverser: Traverser): Unit =
      traverser.traverse(codePoint)

    def transform(transformer: Transformer)(implicit pos: Position): Tree = {
      val newCodePoint = transformer.transform(codePoint)
      Transient(WasmStringFromCodePoint(newCodePoint))
    }

    def printIR(out: IRTreePrinter): Unit = {
      out.print("$stringFromCodePoint")
      out.printArgs(codePoint :: Nil)
    }
  }

  /** Wasm intrinsic for `jl.String.codePointAt`.
   *
   *  Typing rules: `string` must be a `jl.String` and `index` must be an
   *  `int`. The result is an `int`.
   *
   *  Evaluation semantics:
   *
   *  1. Let `stringV` be the result of evaluating `string`. If it is `null`,
   *     throw an NPE (subject to UB).
   *  2. Let `indexV` be the result of evaluating `index`.
   *  3. If `indexV < 0` or `indexV >= stringV.length`, throw a
   *     `StringIndexOutOfBoundsException` (subject to UB).
   *  4. Return the code point that starts at index `indexV` of `stringV`.
   */
  final case class WasmCodePointAt(string: Tree, index: Tree)
      extends Transient.Value {

    val tpe: Type = IntType

    def traverse(traverser: Traverser): Unit = {
      traverser.traverse(string)
      traverser.traverse(index)
    }

    def transform(transformer: Transformer)(implicit pos: Position): Tree = {
      // Children are transformed in evaluation order.
      val newString = transformer.transform(string)
      val newIndex = transformer.transform(index)
      Transient(WasmCodePointAt(newString, newIndex))
    }

    def printIR(out: IRTreePrinter): Unit = {
      out.print("$codePointAt")
      out.printArgs(string :: index :: Nil)
    }
  }

  /** Wasm intrinsic for `jl.String.substring`.
   *
   *  Typing rules: `string` must be a `jl.String`; `start` and `optEnd` (when
   *  present) must be `int`s. The result is a `string`.
   *
   *  Evaluation semantics:
   *
   *  1. Let `stringV` be the result of evaluating `string`. If it is `null`,
   *     throw an NPE (subject to UB).
   *  2. Let `startV` be the result of evaluating `start`.
   *  3. If `optEnd` is empty, let `endV` be `stringV.length`; otherwise let
   *     `endV` be the result of evaluating `optEnd`.
   *  4. If `startV < 0`, `endV < startV` or `endV > stringV.length`, throw a
   *     `StringIndexOutOfBoundsException` (subject to UB).
   *  5. Return the substring of `stringV` in the range `[startV, endV)`.
   */
  final case class WasmSubstring(string: Tree, start: Tree, optEnd: Option[Tree])
      extends Transient.Value {

    val tpe: Type = StringType

    def traverse(traverser: Traverser): Unit = {
      traverser.traverse(string)
      traverser.traverse(start)
      for (endTree <- optEnd)
        traverser.traverse(endTree)
    }

    def transform(transformer: Transformer)(implicit pos: Position): Tree = {
      // Children are transformed in evaluation order.
      val newString = transformer.transform(string)
      val newStart = transformer.transform(start)
      val newOptEnd = transformer.transformTreeOpt(optEnd)
      Transient(WasmSubstring(newString, newStart, newOptEnd))
    }

    def printIR(out: IRTreePrinter): Unit = {
      out.print("$substring")
      // The end argument is only printed when present.
      out.printArgs(List(string, start) ::: optEnd.toList)
    }
  }
}
